# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# This file is part of ByteQC.
#
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from byteqc.cupbc import scf as pscf
from pyscf.lib import logger
from pyscf.pbc import gto
import numpy
import time
import h5py
import sys
import cupy

from byteqc.lib import Mg
import ast

# The optional first command-line argument selects the GPUs to use: an int
# is treated as a GPU count, anything else must be a sequence whose length
# is taken as the count below (e.g. "4" or "[0,1,2]"). Without an argument,
# every CUDA device visible to cupy is used.
if len(sys.argv) > 1:
    # literal_eval only accepts Python literals (ints, lists, tuples, ...),
    # so the argument cannot execute arbitrary code the way eval() could.
    ngpu = ast.literal_eval(sys.argv[1])
else:
    ngpu = cupy.cuda.runtime.getDeviceCount()
Mg.set_gpus(ngpu)
# From here on ngpu is always the number of GPUs (used e.g. in the log name).
if not isinstance(ngpu, int):
    ngpu = len(ngpu)


class DoubleOutput:
    """Tee ``sys.stdout`` to a log file, timestamping the file copy.

    After :meth:`init`, the instance replaces ``sys.stdout``. Every write is
    forwarded unchanged to the previous stdout. It is also appended to the
    log file ``name``, prefixed with ``[<elapsed>]``, where elapsed is the
    time since :meth:`init`. Empty and whitespace-only writes (such as the
    separators and newlines ``print`` emits) go to the file without a prefix.

    Args:
        name: log file path. The default ``log_Cu_<ngpu>G.txt`` is computed
            once, from the module-level ``ngpu``, when the class is defined.
        isinit: if true, call :meth:`init` immediately.
    """

    def __init__(self, name="log_Cu_%dG.txt" % ngpu, isinit=True):
        self.name = name
        # Expose the same encoding as the stream being replaced.
        self.encoding = sys.stdout.encoding
        if isinit:
            self.init()

    def __del__(self):
        # Best effort: give stdout back and close the log file if still open.
        self.restore()

    def init(self):
        """Open the log file in append mode and install self as sys.stdout."""
        self.file = open(self.name, "a")
        self.stdout_old = sys.stdout
        sys.stdout = self
        self.time = time.time()
        self.write(time.strftime(
            "Start logging at %m%d-%H:%M:%S\n", time.localtime()))

    def restore(self):
        """Reinstate the previous stdout and close the log file.

        Safe to call more than once, and safe if :meth:`init` never ran.
        """
        log_file = getattr(self, "file", None)
        if log_file is None or log_file.closed:
            return
        sys.stdout = self.stdout_old
        log_file.close()

    def difftime(self):
        """Return the time since :meth:`init` as a short string, e.g. '1h2m3s'."""
        t2 = time.time()
        dt = t2 - self.time
        if dt < 1:
            return "%.3fs" % dt
        if dt < 60:
            return "%.1fs" % dt
        dt = round(dt)
        strtime = '%ds' % (dt % 60)
        dt //= 60
        strtime = '%dm' % (dt % 60) + strtime
        if dt < 60:
            return strtime
        dt //= 60
        strtime = '%dh' % (dt % 24) + strtime
        if dt < 24:
            return strtime
        dt //= 24
        return '%dd' % dt + strtime

    def write(self, data):
        """Write ``data`` to both sinks; return the number of characters written."""
        if not data or data.isspace():
            # No timestamp for empty/whitespace chunks, so that separators and
            # line breaks from print() do not each get their own prefix.
            self.file.write(data)
        else:
            self.file.write(f"[{self.difftime()}]{data}")
        self.stdout_old.write(data)
        # Match the text-stream protocol: write() returns the character count.
        return len(data)

    def flush(self):
        """Flush both the original stdout and the log file."""
        self.stdout_old.flush()
        self.file.flush()


# From here on, everything printed to stdout is also appended (with elapsed
# time prefixes) to the DoubleOutput log file; isinit defaults to True.
logout = DoubleOutput()

# Atomic geometry passed to gto.Cell, one "symbol x y z" entry per line.
# - 36 atoms labelled "Cu1", all at z = 10.0. The label lets the ``basis``
#   dict below give them a different basis ('gth-szv-molopt-sr') from the
#   default.
# - 36 plain "Cu" atoms at z ~ 12.097.
# - One O and one C atom above the Cu layers (presumably an adsorbed CO
#   molecule; NOTE(review): confirm).
# No ``unit`` is passed to gto.Cell, so the Cell's default length unit
# applies (presumably Angstrom; confirm).
atom = '''
Cu1  0.0000000000000000  0.0000000000000000 10.0000000000000000
Cu1  2.5685145663822495  0.0000000000000000 10.0000000000000000
Cu1  5.1370291327644972  0.0000000000000000 10.0000000000000000
Cu1  7.7055436991467463  0.0000000000000000 10.0000000000000000
Cu1 10.2740582655289945  0.0000000000000000 10.0000000000000000
Cu1 12.8425728319112427  0.0000000000000000 10.0000000000000000
Cu1  1.2842572831911248  2.2243988644773993 10.0000000000000000
Cu1  3.8527718495733740  2.2243988644773993 10.0000000000000000
Cu1  6.4212864159556222  2.2243988644773993 10.0000000000000000
Cu1  8.9898009823378704  2.2243988644773993 10.0000000000000000
Cu1 11.5583155487201186  2.2243988644773993 10.0000000000000000
Cu1 14.1268301151023667  2.2243988644773993 10.0000000000000000
Cu1  2.5685145663822495  4.4487977289547986 10.0000000000000000
Cu1  5.1370291327644972  4.4487977289547986 10.0000000000000000
Cu1  7.7055436991467463  4.4487977289547986 10.0000000000000000
Cu1 10.2740582655289945  4.4487977289547986 10.0000000000000000
Cu1 12.8425728319112444  4.4487977289547986 10.0000000000000000
Cu1 15.4110873982934926  4.4487977289547986 10.0000000000000000
Cu1  3.8527718495733732  6.6731965934321966 10.0000000000000000
Cu1  6.4212864159556204  6.6731965934321966 10.0000000000000000
Cu1  8.9898009823378704  6.6731965934321966 10.0000000000000000
Cu1 11.5583155487201168  6.6731965934321966 10.0000000000000000
Cu1 14.1268301151023650  6.6731965934321966 10.0000000000000000
Cu1 16.6953446814846167  6.6731965934321966 10.0000000000000000
Cu1  5.1370291327644964  8.8975954579095973 10.0000000000000000
Cu1  7.7055436991467454  8.8975954579095973 10.0000000000000000
Cu1 10.2740582655289945  8.8975954579095973 10.0000000000000000
Cu1 12.8425728319112444  8.8975954579095973 10.0000000000000000
Cu1 15.4110873982934926  8.8975954579095973 10.0000000000000000
Cu1 17.9796019646757408  8.8975954579095973 10.0000000000000000
Cu1  6.4212864159556204 11.1219943223869944 10.0000000000000000
Cu1  8.9898009823378704 11.1219943223869944 10.0000000000000000
Cu1 11.5583155487201186 11.1219943223869944 10.0000000000000000
Cu1 14.1268301151023685 11.1219943223869944 10.0000000000000000
Cu1 16.6953446814846167 11.1219943223869944 10.0000000000000000
Cu1 19.2638592478668649 11.1219943223869944 10.0000000000000000
Cu  1.2842572831911254  0.7414662881591335 12.0971833615141655
Cu  3.8527718495733732  0.7414662881591335 12.0971833615141655
Cu  6.4212864159556222  0.7414662881591335 12.0971833615141655
Cu  8.9898009823378722  0.7414662881591335 12.0971833615141655
Cu 11.5583155487201186  0.7414662881591335 12.0971833615141655
Cu 14.1268301151023685  0.7414662881591335 12.0971833615141655
Cu  2.5685145663822477  2.9658651526365318 12.0971833615141655
Cu  5.1370291327644972  2.9658651526365318 12.0971833615141655
Cu  7.7055436991467463  2.9658651526365318 12.0971833615141655
Cu 10.2740582655289945  2.9658651526365318 12.0971833615141655
Cu 12.8425728319112409  2.9658651526365318 12.0971833615141655
Cu 15.4110873982934908  2.9658651526365318 12.0971833615141655
Cu  3.8527718495733740  5.1902640171139307 12.0971833615141655
Cu  6.4212864159556213  5.1902640171139307 12.0971833615141655
Cu  8.9898009823378704  5.1902640171139307 12.0971833615141655
Cu 11.5583155487201203  5.1902640171139307 12.0971833615141655
Cu 14.1268301151023667  5.1902640171139307 12.0971833615141655
Cu 16.6953446814846167  5.1902640171139307 12.0971833615141655
Cu  5.1370291327644964  7.4146628815913296 12.0971833615141655
Cu  7.7055436991467463  7.4146628815913296 12.0971833615141655
Cu 10.2740582655289927  7.4146628815913296 12.0971833615141655
Cu 12.8425728319112444  7.4146628815913296 12.0971833615141655
Cu 15.4110873982934908  7.4146628815913296 12.0971833615141655
Cu 17.9796019646757408  7.4146628815913296 12.0971833615141655
Cu  6.4212864159556213  9.6390617460687302 12.0971833615141655
Cu  8.9898009823378704  9.6390617460687302 12.0971833615141655
Cu 11.5583155487201186  9.6390617460687302 12.0971833615141655
Cu 14.1268301151023685  9.6390617460687302 12.0971833615141655
Cu 16.6953446814846131  9.6390617460687302 12.0971833615141655
Cu 19.2638592478668649  9.6390617460687302 12.0971833615141655
Cu  7.7055436991467454 11.8634606105461273 12.0971833615141655
Cu 10.2740582655289945 11.8634606105461273 12.0971833615141655
Cu 12.8425728319112427 11.8634606105461273 12.0971833615141655
Cu 15.4110873982934926 11.8634606105461273 12.0971833615141655
Cu 17.9796019646757372 11.8634606105461273 12.0971833615141655
Cu 20.5481165310579890 11.8634606105461273 12.0971833615141655
O  11.5583229389528075  6.6731960734513303 19.3350876935844518
C  11.5583218518554656  6.6731985822560622 18.1786368501778419
'''
# 3x3 lattice matrix passed to gto.Cell as ``a`` (PySCF reads each row as
# one lattice vector). The third vector points along z and is much longer
# than the Cu layers are thick, presumably leaving a vacuum gap (confirm).
a = '''15.4110873982934926    0.0000000000000000    0.0000000000000000
7.7055436991467463   13.3463931868643932    0.0000000000000000
0.0000000000000000    0.0000000000000000   26.2915500845425001
'''
# ---- Run parameters (also read by run_rsdf and wrap) ----
omega = 0.1                   # assigned to mf.with_df.omega in run_rsdf
kmesh = [4, 4, 5]             # k-point mesh handed to cell.make_kpts
scaled_center = None          # no shift of the k-point mesh
# Atoms labelled 'Cu1' get the smaller basis; every other atom the default.
basis = {'default': 'gth-dzvp-molopt-sr', 'Cu1': 'gth-szv-molopt-sr'}
xc = 'PBE'
pseudo = 'gth-pbe'
max_cycle = 3                 # SCF iteration cap passed to run_rsdf

# HDF5 group name used by wrap() to cache results in LargeCu.dat: the three
# kmesh digits followed by the reprs of the basis dict and pseudo name.
key = "%d%d%d" % tuple(kmesh) + "%s%s" % (basis, pseudo)

# ---- Build the periodic cell ----
cell = gto.Cell(atom=atom, basis=basis, a=a, pseudo=pseudo)
cell.verbose = 9
cell.build()

# Total orbital count over all k-points, shell count, and per-cell AO count.
print("norbital:%d nbas:%s nao:%d" % (
    cell.nao * numpy.prod(kmesh),
    cell.nbas,
    cell.nao,
))

kpts = cell.make_kpts(kmesh, scaled_center=scaled_center)
log = logger.Logger(cell.stdout, 6)
log.info("kmesh= %s", kmesh)
log.info("kpts = %s", kpts)


def wrap(name, func):
    """Return a version of ``func`` whose result is cached in ``LargeCu.dat``.

    The cache entry is the HDF5 dataset ``"<key>/<name>"``, where ``key`` is
    the module-level run identifier. On the first call, ``func`` is evaluated
    and its result stored. Later calls (also in later runs) read the stored
    array back instead of recomputing it.

    Args:
        name: dataset name for this quantity inside the ``key`` group.
        func: callable to cache. Its arguments are NOT part of the cache key.

    Returns:
        A wrapper with the same call signature as ``func``. On a cache hit
        it returns ``f[n][:]``, i.e. the data read back from HDF5.
    """
    def g(*arg, **kwarg):
        n = "%s/%s" % (key, name)
        try:
            # Mode "a" opens the file read/write, creating it if missing.
            f = h5py.File("LargeCu.dat", "a")
        except OSError:
            # The existing file could not be opened (e.g. unreadable).
            # Recreate it; note that mode "w" discards any existing cache.
            f = h5py.File("LargeCu.dat", "w")
        # ``with`` guarantees the file is closed (and its lock released)
        # even if ``func`` raises.
        with f:
            if n not in f:
                r = func(*arg, **kwarg)
                # NOTE(review): assumes ``r`` is something h5py can store
                # (e.g. a numpy array); confirm for GPU-array results.
                f[n] = r
            else:
                print("Reading", name)
                r = f[n][:]
        return r
    return g


def run_rsdf(scf, max_cycle=3, verbose=9):
    """Run a k-point RKS calculation with range-separated density fitting.

    Builds ``scf.KRKS`` on the module-level ``cell`` and ``kpts`` and switches
    it to ``rs_density_fit()``. The initial guess, core Hamiltonian and
    ``with_df.get_pp`` are cached in ``LargeCu.dat`` via ``wrap``. The fixed
    run settings below are then applied and the SCF kernel is run.

    Args:
        scf: module providing ``KRKS`` (the caller passes ``byteqc.cupbc.scf``).
        max_cycle: maximum number of SCF iterations.
        verbose: verbosity set on both the module-level ``cell`` and ``mf``.

    Returns:
        ``mf.e_tot`` after ``mf.kernel()``.
    """
    cell.verbose = verbose
    mf = scf.KRKS(cell, kpts=kpts).rs_density_fit()
    df = mf.with_df

    # (object, method to replace, cache dataset name used by wrap)
    cached_methods = (
        (mf, 'get_init_guess', 'dm'),
        (mf, 'get_hcore', 'h1e'),
        (df, 'get_pp', 'get_pp'),
    )
    for owner, method, tag in cached_methods:
        setattr(owner, method, wrap(tag, getattr(owner, method)))

    # Mean-field object settings.
    mf.xc = xc
    mf.verbose = verbose
    mf.max_memory = 1100000
    mf.max_cycle = max_cycle

    # Density-fitting object settings.
    df.max_memory = 1100000
    df.omega = omega
    df.direct = True
    df.ksym = 's1'
    df.use_bvk = [True, True]

    mf.kernel()
    return mf.e_tot


# Run the SCF with ByteQC's cupbc scf module and keep the total energy.
egpu = run_rsdf(pscf, max_cycle=max_cycle)
